{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 0, 1, 0, 0, 0, 0],\n",
       "        [1, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 1, 0, 0, 1, 1, 0],\n",
       "        [1, 1, 0, 1, 1, 1, 0],\n",
       "        [0, 2, 2, 0, 2, 1, 1],\n",
       "        [2, 1, 1, 1, 0, 0, 1],\n",
       "        [1, 1, 0, 0, 1, 1, 1],\n",
       "        [2, 0, 0, 2, 2, 0, 1],\n",
       "        [0, 0, 1, 1, 1, 0, 1],\n",
       "        [0, 0, 1, 0, 0, 0, 0],\n",
       "        [2, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 0, 0, 1, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 0, 1],\n",
       "        [2, 2, 2, 2, 2, 0, 1],\n",
       "        [2, 0, 0, 2, 2, 1, 1],\n",
       "        [0, 1, 0, 1, 0, 0, 1]]),\n",
       " array([[0, 0, 1, 0, 0, 0, 0],\n",
       "        [2, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 0, 0, 1, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 0, 1],\n",
       "        [2, 2, 2, 2, 2, 0, 1],\n",
       "        [2, 0, 0, 2, 2, 1, 1],\n",
       "        [0, 1, 0, 1, 0, 0, 1]]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "# Load the data set\n",
    "def load_data(path='决策树数据.txt', test_start=10):\n",
    "    \"\"\"Read a comma-separated integer table and slice off the test rows.\n",
    "\n",
    "    Args:\n",
    "        path: Text file with one sample per line, fields separated by ','.\n",
    "        test_start: Index of the first row that is also returned as test data.\n",
    "\n",
    "    Returns:\n",
    "        (x, test_x): x holds every parsed row as a 2-D int array and\n",
    "        test_x is the slice x[test_start:].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a field is not an integer or the rows differ in length.\n",
    "    \"\"\"\n",
    "    with open(path) as fr:\n",
    "        # Blank lines (typically a trailing newline at EOF) carry no sample\n",
    "        # and would otherwise fail the int conversion below.\n",
    "        rows = [line.strip().split(',') for line in fr if line.strip()]\n",
    "\n",
    "    # Size the array from the data instead of assuming exactly 7 columns.\n",
    "    n_cols = len(rows[0]) if rows else 0\n",
    "    x = np.empty((len(rows), n_cols), dtype=int)\n",
    "\n",
    "    for i, row in enumerate(rows):\n",
    "        # Assigning the string fields converts them to int\n",
    "        x[i] = row\n",
    "\n",
    "    # x is not truncated to the rows before test_start, so test_x is a\n",
    "    # subset of x rather than a disjoint hold-out set.\n",
    "    test_x = x[test_start:]\n",
    "\n",
    "    return x, test_x\n",
    "\n",
    "\n",
    "# Load the full table and its test slice into module-level names\n",
    "x, test_x = load_data()\n",
    "# Echo both arrays as the cell result\n",
    "x, test_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.49732620320855625"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CART splits on the Gini index rather than on entropy\n",
    "def get_gini(_x, col, value):\n",
    "    \"\"\"Weighted Gini impurity of the binary split _x[:, col] == value.\n",
    "\n",
    "    Args:\n",
    "        _x: 2-D array of samples; the last column holds the class label.\n",
    "        col: Index of the feature column to test.\n",
    "        value: Value separating the 'equal' from the 'not equal' partition.\n",
    "\n",
    "    Returns:\n",
    "        Sum over both partitions of (partition share) * (1 - sum_k p_k^2),\n",
    "        where p_k is the share of label k inside the partition. An empty\n",
    "        partition contributes a penalty of 1e20 instead.\n",
    "    \"\"\"\n",
    "    gini = 0\n",
    "    eq_mask = _x[:, col] == value\n",
    "\n",
    "    # Split the rows into those equal to value and those that are not\n",
    "    for mask in (eq_mask, ~eq_mask):\n",
    "        sub_x = _x[mask]\n",
    "\n",
    "        if len(sub_x) == 0:\n",
    "            # Penalise and skip: the ratios below would divide by zero and\n",
    "            # turn the whole result into NaN.\n",
    "            gini += 1e20\n",
    "            continue\n",
    "\n",
    "        # Share of all rows that fall into this partition\n",
    "        prob = len(sub_x) / len(_x)\n",
    "\n",
    "        # Impurity = 1 - sum of squared label shares. Subtracting one label\n",
    "        # at a time, in sorted label order, supports any number of classes\n",
    "        # and evaluates 0/1 labels in the same order as 1 - p0^2 - p1^2.\n",
    "        _, counts = np.unique(sub_x[:, -1], return_counts=True)\n",
    "        impurity = 1\n",
    "        for count in counts:\n",
    "            impurity -= np.power(count / len(sub_x), 2)\n",
    "\n",
    "        gini += prob * impurity\n",
    "\n",
    "    # The weighted sum over both partitions is the Gini of the split\n",
    "    return gini\n",
    "\n",
    "\n",
    "# Example: weighted Gini of splitting the full data set on column 0 == 0\n",
    "get_gini(x, 0, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_split_col_value(_x):\n",
    "    \"\"\"Pick the (column, value) test whose binary split has the lowest Gini.\n",
    "\n",
    "    Every column except the last (the label) and every distinct value in it\n",
    "    is tried as an 'equal / not equal' split.\n",
    "\n",
    "    Returns:\n",
    "        (col, value) of the best split, or (None, None) when no candidate\n",
    "        scores below the initial bound of 1e20, e.g. because no column\n",
    "        holds more than one distinct value.\n",
    "    \"\"\"\n",
    "    best_col, best_value = None, None\n",
    "    best_gini = 1e20\n",
    "    n_rows = len(_x)\n",
    "\n",
    "    # The last column is the label, so it is never a split candidate\n",
    "    for col in range(_x.shape[1] - 1):\n",
    "        column = _x[:, col]\n",
    "\n",
    "        # Try every distinct value of this column\n",
    "        for value in set(column):\n",
    "            n_equal = np.sum(column == value)\n",
    "\n",
    "            # A value shared by every row (or by none) separates nothing\n",
    "            if not 0 < n_equal < n_rows:\n",
    "                continue\n",
    "\n",
    "            # Lower weighted Gini means purer partitions\n",
    "            candidate = get_gini(_x, col, value)\n",
    "\n",
    "            # Strict '<' keeps the first of several equally good splits\n",
    "            if candidate < best_gini:\n",
    "                best_gini, best_col, best_value = candidate, col, value\n",
    "\n",
    "    return best_col, best_value\n",
    "\n",
    "\n",
    "# Example: best (column, value) split for the full data set\n",
    "get_split_col_value(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node col=0 value=0\n",
      "Leaf y=1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None, None)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Node and leaf objects used to build the tree\n",
    "class Node:\n",
    "    \"\"\"Internal tree node describing the test 'column col equals value'.\n",
    "\n",
    "    children maps a branch key to the child object and starts out empty.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, col, value):\n",
    "        self.children = {}\n",
    "        self.col, self.value = col, value\n",
    "\n",
    "    def __str__(self):\n",
    "        fields = (self.col, self.value)\n",
    "        return 'Node col=%d value=%d' % fields\n",
    "\n",
    "\n",
    "class Leaf:\n",
    "    \"\"\"Terminal tree node that stores a label y.\"\"\"\n",
    "\n",
    "    def __init__(self, y):\n",
    "        self.y = y\n",
    "\n",
    "    def __str__(self):\n",
    "        label = self.y\n",
    "        return 'Leaf y=%d' % label\n",
    "\n",
    "\n",
    "# Show the text form of each class. Two separate statements avoid building\n",
    "# a (None, None) tuple that the notebook would echo as a spurious result.\n",
    "print(Node(0, 0))\n",
    "print(Leaf(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 value=0 \n"
     ]
    }
   ],
   "source": [
    "# Helper that prints the tree\n",
    "def print_tree(node, prefix='', subfix=''):\n",
    "    \"\"\"Print node and all of its descendants, one line per node.\n",
    "\n",
    "    Args:\n",
    "        node: Node or Leaf to start from.\n",
    "        prefix: Indentation accumulated so far; four dashes are appended\n",
    "            for each level.\n",
    "        subfix: Text printed after the node; recursive calls pass the key\n",
    "            of the branch that leads to the child.\n",
    "    \"\"\"\n",
    "    indent = prefix + '-' * 4\n",
    "    print(indent, node, subfix)\n",
    "\n",
    "    # Only non-leaf nodes have children to descend into\n",
    "    if not isinstance(node, Leaf):\n",
    "        for symbol, child in node.children.items():\n",
    "            print_tree(child, indent, 'symbol=' + str(symbol))\n",
    "\n",
    "\n",
    "# Example: a node without children prints a single line\n",
    "print_tree(Node(0, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 0)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Find the best split over the full data set (lowest weighted Gini);\n",
    "# the result is (3, 0)\n",
    "get_split_col_value(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node col=3 value=0\n"
     ]
    }
   ],
   "source": [
    "# Create the root from the result above: it splits the data on\n",
    "# column 3 == 0\n",
    "root = Node(3, 0)\n",
    "print(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=3 value=0 \n",
      "-------- Node col=5 value=0 symbol=eq\n",
      "-------- Node col=0 value=1 symbol=neq\n"
     ]
    }
   ],
   "source": [
    "# Attach the children of a node\n",
    "def create_children(_x, parent_node):\n",
    "    \"\"\"Attach the 'eq' and 'neq' children of parent_node, one level deep.\n",
    "\n",
    "    A partition becomes a Leaf when all of its labels agree or when it\n",
    "    cannot be split any further; otherwise it becomes a Node carrying the\n",
    "    best next split. Grandchildren are not created here.\n",
    "\n",
    "    Args:\n",
    "        _x: Rows that reach parent_node; the last column is the label.\n",
    "        parent_node: Node whose col/value test partitions _x. Its children\n",
    "            dict is filled in place.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If _x has no rows.\n",
    "    \"\"\"\n",
    "    if len(_x) == 0:\n",
    "        raise ValueError('cannot create children from an empty data set')\n",
    "\n",
    "    # Partition the rows by the parent's test\n",
    "    eq_mask = _x[:, parent_node.col] == parent_node.value\n",
    "\n",
    "    for symbol, mask in (('eq', eq_mask), ('neq', ~eq_mask)):\n",
    "        sub_x = _x[mask]\n",
    "\n",
    "        # An empty side has no labels of its own; use the parent's rows\n",
    "        labels = sub_x[:, -1] if len(sub_x) else _x[:, -1]\n",
    "        unique_y, counts = np.unique(labels, return_counts=True)\n",
    "\n",
    "        # Mixed labels: try to split this partition further\n",
    "        if len(unique_y) > 1 and len(sub_x):\n",
    "            split_col, split_value = get_split_col_value(sub_x)\n",
    "\n",
    "            # (None, None) means no column can separate the rows, so a\n",
    "            # Node built from it would be unusable.\n",
    "            if split_col is not None:\n",
    "                parent_node.children[symbol] = Node(col=split_col, value=split_value)\n",
    "                continue\n",
    "\n",
    "        # Pure, empty or unsplittable partition: leaf with the most frequent\n",
    "        # label (the only label when the partition is pure)\n",
    "        parent_node.children[symbol] = Leaf(unique_y[np.argmax(counts)])\n",
    "\n",
    "\n",
    "# Attach the first level below the root, then show the tree so far\n",
    "create_children(x, root)\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=3 value=0 \n",
      "-------- Node col=5 value=0 symbol=eq\n",
      "------------ Leaf y=0 symbol=eq\n",
      "------------ Node col=0 value=0 symbol=neq\n",
      "-------- Node col=0 value=1 symbol=neq\n",
      "------------ Node col=2 value=0 symbol=eq\n",
      "------------ Leaf y=1 symbol=neq\n"
     ]
    }
   ],
   "source": [
    "# Build the next level: partition x by the root test (column 3 == 0)\n",
    "# and expand each child of the root with the rows that reach it\n",
    "x_3_eq_0 = x[x[:, 3] == 0]\n",
    "x_3_neq_0 = x[x[:, 3] != 0]\n",
    "create_children(x_3_eq_0, root.children['eq'])\n",
    "create_children(x_3_neq_0, root.children['neq'])\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=3 value=0 \n",
      "-------- Node col=5 value=0 symbol=eq\n",
      "------------ Leaf y=0 symbol=eq\n",
      "------------ Node col=0 value=0 symbol=neq\n",
      "---------------- Node col=1 value=1 symbol=eq\n",
      "---------------- Leaf y=1 symbol=neq\n",
      "-------- Node col=0 value=1 symbol=neq\n",
      "------------ Node col=2 value=0 symbol=eq\n",
      "---------------- Leaf y=0 symbol=eq\n",
      "---------------- Leaf y=1 symbol=neq\n",
      "------------ Leaf y=1 symbol=neq\n"
     ]
    }
   ],
   "source": [
    "# Build the next level\n",
    "# Rows with column 3 == 0 and column 5 != 0 reach the 'neq' child of root['eq']\n",
    "x_3_eq_0_and_5_neq_0 = x_3_eq_0[x_3_eq_0[:, 5] != 0]\n",
    "create_children(x_3_eq_0_and_5_neq_0, root.children['eq'].children['neq'])\n",
    "\n",
    "# Rows with column 3 != 0 and column 0 == 1 reach the 'eq' child of root['neq']\n",
    "x_3_neq_0_and_0_eq_1 = x_3_neq_0[x_3_neq_0[:, 0] == 1]\n",
    "create_children(x_3_neq_0_and_0_eq_1, root.children['neq'].children['eq'])\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=3 value=0 \n",
      "-------- Node col=5 value=0 symbol=eq\n",
      "------------ Leaf y=0 symbol=eq\n",
      "------------ Node col=0 value=0 symbol=neq\n",
      "---------------- Node col=1 value=1 symbol=eq\n",
      "-------------------- Leaf y=0 symbol=eq\n",
      "-------------------- Leaf y=1 symbol=neq\n",
      "---------------- Leaf y=1 symbol=neq\n",
      "-------- Node col=0 value=1 symbol=neq\n",
      "------------ Node col=2 value=0 symbol=eq\n",
      "---------------- Leaf y=0 symbol=eq\n",
      "---------------- Leaf y=1 symbol=neq\n",
      "------------ Leaf y=1 symbol=neq\n"
     ]
    }
   ],
   "source": [
    "# Build the next level: rows with column 3 == 0, column 5 != 0 and\n",
    "# column 0 == 0 reach the node expanded here\n",
    "x_3_eq_0_and_5_neq_0_and_0_eq_0 = x_3_eq_0_and_5_neq_0[\n",
    "    x_3_eq_0_and_5_neq_0[:, 0] == 0]\n",
    "create_children(x_3_eq_0_and_5_neq_0_and_0_eq_0,\n",
    "                root.children['eq'].children['neq'].children['eq'])\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n"
     ]
    }
   ],
   "source": [
    "# Prediction helper\n",
    "def pred(_x, node):\n",
    "    \"\"\"Return the label the tree below node assigns to the row _x.\n",
    "\n",
    "    Starting at node, follow the 'eq' branch when _x[col] equals the node's\n",
    "    value and the 'neq' branch otherwise, until a Leaf is reached.\n",
    "    \"\"\"\n",
    "    current = node\n",
    "    while True:\n",
    "        branch = 'neq' if _x[current.col] != current.value else 'eq'\n",
    "        current = current.children[branch]\n",
    "\n",
    "        # A leaf carries the final answer\n",
    "        if isinstance(current, Leaf):\n",
    "            return current.y\n",
    "\n",
    "\n",
    "# Score the tree on x itself, i.e. on the rows it was built from\n",
    "correct = 0\n",
    "for i in x:\n",
    "    # The last element of each row is the true label\n",
    "    if pred(i, root) == i[-1]:\n",
    "        correct += 1\n",
    "\n",
    "# Fraction of rows predicted correctly\n",
    "print(correct / len(x))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
